{
 "cells": [
   {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://kaggle.com/kernels/welcome?src=https://github.com/yuval-alaluf/restyle-encoder/blob/main/notebooks/inference_playground.ipynb\"><img align=\"left\" alt=\"Kaggle\" title=\"Open in Kaggle\" src=\"https://kaggle.com/static/images/open-in-kaggle.svg\"></a>",
    "<a href=\"https://colab.research.google.com/github/yuval-alaluf/restyle-encoder/blob/main/notebooks/inference_playground.ipynb\"><img align=\"left\" title=\"Open in Colab\" src=\"https://colab.research.google.com/assets/colab-badge.svg\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Uuviq3qQkUFy"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('/content')\n",
    "CODE_DIR = 'restyle-encoder'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QQ6XEmlHlXbk"
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/yuval-alaluf/restyle-encoder.git $CODE_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JaRUFuVHkzye"
   },
   "outputs": [],
   "source": [
    "!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip\n",
    "!sudo unzip ninja-linux.zip -d /usr/local/bin/\n",
    "!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "23baccYQlU9E"
   },
   "outputs": [],
   "source": [
    "os.chdir(f'./{CODE_DIR}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "d13v7In0kTJn"
   },
   "outputs": [],
   "source": [
    "from argparse import Namespace\n",
    "import time\n",
    "import os\n",
    "import sys\n",
    "import pprint\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "sys.path.append(\".\")\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "from utils.common import tensor2im\n",
    "from models.psp import pSp\n",
    "from models.e4e import e4e\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HRjtz6uLkTJs"
   },
   "source": [
    "## Step 1: Select Experiment Type\n",
    "Select which experiment you wish to perform inference on:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XESWAO65kTJt"
   },
   "outputs": [],
   "source": [
    "#@title Select which experiment you wish to perform inference on: { run: \"auto\" }\n",
    "experiment_type = 'ffhq_encode' #@param ['ffhq_encode', 'cars_encode', 'church_encode', 'horse_encode', 'afhq_wild_encode', 'toonify']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4etDz82xkTJz"
   },
   "source": [
    "## Step 2: Prepare to Download Pretrained Models \n",
    "As part of this repository, we provide pretrained models for each of the above experiments. Here, we'll create the download command needed for downloading the desired model.\n",
    "\n",
    "Note: in this notebook, we'll be using ReStyle applied over pSp for all domains except for the horses domain where we'll be using e4e. This is done since e4e is generally able to generate more realistic reconstructions on this domain. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KSnjlBZOkTJ0"
   },
   "outputs": [],
   "source": [
    "def get_download_model_command(file_id, file_name):\n",
    "    \"\"\" Get wget download command for downloading the desired model and save to directory ../pretrained_models. \"\"\"\n",
    "    current_directory = os.getcwd()\n",
    "    save_path = os.path.join(os.path.dirname(current_directory), CODE_DIR, \"pretrained_models\")\n",
    "    if not os.path.exists(save_path):\n",
    "        os.makedirs(save_path)\n",
    "    url = r\"\"\"wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id={FILE_ID}\" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt\"\"\".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)\n",
    "    return url    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "m4sjldFMkTJ5"
   },
   "outputs": [],
   "source": [
    "MODEL_PATHS = {\n",
    "    \"ffhq_encode\": {\"id\": \"1sw6I2lRIB0MpuJkpc8F5BJiSZrc0hjfE\", \"name\": \"restyle_psp_ffhq_encode.pt\"},\n",
    "    \"cars_encode\": {\"id\": \"1zJHqHRQ8NOnVohVVCGbeYMMr6PDhRpPR\", \"name\": \"restyle_psp_cars_encode.pt\"},\n",
    "    \"church_encode\": {\"id\": \"1bcxx7mw-1z7dzbJI_z7oGpWG1oQAvMaD\", \"name\": \"restyle_psp_church_encode.pt\"},\n",
    "    \"horse_encode\": {\"id\": \"19_sUpTYtJmhSAolKLm3VgI-ptYqd-hgY\", \"name\": \"restyle_e4e_horse_encode.pt\"},\n",
    "    \"afhq_wild_encode\": {\"id\": \"1GyFXVTNDUw3IIGHmGS71ChhJ1Rmslhk7\", \"name\": \"restyle_psp_afhq_wild_encode.pt\"},\n",
    "    \"toonify\": {\"id\": \"1GtudVDig59d4HJ_8bGEniz5huaTSGO_0\", \"name\": \"restyle_psp_toonify.pt\"}\n",
    "}\n",
    "\n",
    "path = MODEL_PATHS[experiment_type]\n",
    "download_command = get_download_model_command(file_id=path[\"id\"], file_name=path[\"name\"]) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9Tozsg81kTKA"
   },
   "source": [
    "## Step 3: Define Inference Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XIhyc7RqkTKB"
   },
   "source": [
    "Below we have a dictionary defining parameters such as the path to the pretrained model to use and the path to the image to perform inference on.  \n",
    "While we provide default values to run this script, feel free to change as needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2kE5y1-skTKC"
   },
   "outputs": [],
   "source": [
    "EXPERIMENT_DATA_ARGS = {\n",
    "    \"ffhq_encode\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_psp_ffhq_encode.pt\",\n",
    "        \"image_path\": \"notebooks/images/face_img.jpg\",\n",
    "        \"transform\": transforms.Compose([\n",
    "            transforms.Resize((256, 256)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "    },\n",
    "    \"cars_encode\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_psp_cars_encode.pt\",\n",
    "        \"image_path\": \"notebooks/images/car_img.jpg\",\n",
    "        \"transform\": transforms.Compose([\n",
    "            transforms.Resize((192, 256)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "    },\n",
    "    \"church_encode\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_psp_church_encode.pt\",\n",
    "        \"image_path\": \"notebooks/images/church_img.jpg\",\n",
    "        \"transform\": transforms.Compose([\n",
    "            transforms.Resize((256, 256)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "    },\n",
    "    \"horse_encode\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_e4e_horse_encode.pt\",\n",
    "        \"image_path\": \"notebooks/images/horse_img.jpg\",\n",
    "        \"transform\": transforms.Compose([\n",
    "            transforms.Resize((256, 256)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "    },\n",
    "    \"afhq_wild_encode\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_psp_afhq_wild_encode.pt\",\n",
    "        \"image_path\": \"notebooks/images/afhq_wild_img.jpg\",\n",
    "        \"transform\": transforms.Compose([\n",
    "            transforms.Resize((256, 256)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "    },\n",
    "    \"toonify\": {\n",
    "        \"model_path\": \"pretrained_models/restyle_psp_toonify.pt\",\n",
    "        \"image_path\": \"notebooks/images/toonify_img.jpg\",\n",
    "        \"transform\": transforms.Compose([\n",
    "            transforms.Resize((256, 256)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IzUHoD9ukTKG"
   },
   "outputs": [],
   "source": [
    "EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To reduce the number of requests to fetch the model, we'll check if the model was previously downloaded and saved before downloading the model.  \n",
    "We'll download the model for the selected experiment and save it to the folder `../pretrained_models`.\n",
    "\n",
    "We also need to verify that the model was downloaded correctly. All of our models should weigh approximately 800MB - 1GB.  \n",
    "Note that if the file weighs several KBs, you most likely encounter a \"quota exceeded\" error from Google Drive. In that case, you should try downloading the model again after a few hours."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jQ31J_m7kTJ8"
   },
   "outputs": [],
   "source": [
    "if not os.path.exists(EXPERIMENT_ARGS['model_path']) or os.path.getsize(EXPERIMENT_ARGS['model_path']) < 1000000:\n",
    "    print(f'Downloading ReStyle model for {experiment_type}...')\n",
    "    os.system(f\"wget {download_command}\")\n",
    "    # if google drive receives too many requests, we'll reach the quota limit and be unable to download the model\n",
    "    if os.path.getsize(EXPERIMENT_ARGS['model_path']) < 1000000:\n",
    "        raise ValueError(\"Pretrained model was unable to be downloaded correctly!\")\n",
    "    else:\n",
    "        print('Done.')\n",
    "else:\n",
    "    print(f'ReStyle model for {experiment_type} already exists!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TAWrUehTkTKJ"
   },
   "source": [
    "## Step 4: Load Pretrained Model\n",
    "We assume that you have downloaded all relevant models and placed them in the directory defined by the above dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1t-AOhP1kTKJ"
   },
   "outputs": [],
   "source": [
    "model_path = EXPERIMENT_ARGS['model_path']\n",
    "ckpt = torch.load(model_path, map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2UBwJ3dJkTKM"
   },
   "outputs": [],
   "source": [
    "opts = ckpt['opts']\n",
    "pprint.pprint(opts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EMKhWoFKkTKS"
   },
   "outputs": [],
   "source": [
    "# update the training options\n",
    "opts['checkpoint_path'] = model_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6hccfNizkTKW"
   },
   "outputs": [],
   "source": [
    "opts = Namespace(**opts)\n",
    "if experiment_type == 'horse_encode': \n",
    "    net = e4e(opts)\n",
    "else:\n",
    "    net = pSp(opts)\n",
    "    \n",
    "net.eval()\n",
    "net.cuda()\n",
    "print('Model successfully loaded!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4weLFoPbkTKZ"
   },
   "source": [
    "## Step 5: Visualize Input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "r2H9zFLJkTKa"
   },
   "outputs": [],
   "source": [
    "image_path = EXPERIMENT_DATA_ARGS[experiment_type][\"image_path\"]\n",
    "original_image = Image.open(image_path).convert(\"RGB\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-lbLKtl-kTKc"
   },
   "outputs": [],
   "source": [
    "if experiment_type == 'cars_encode':\n",
    "    original_image = original_image.resize((192, 256))\n",
    "else:\n",
    "    original_image = original_image.resize((256, 256))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "o6oqf8JwzK0K"
   },
   "source": [
    "### Align Image\n",
    "\n",
    "Note: in this notebook we'll run alignment on the input image when working on the human facial domain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hJ9Ce1aYzmFF"
   },
   "outputs": [],
   "source": [
    "def run_alignment(image_path):\n",
    "    import dlib\n",
    "    from scripts.align_faces_parallel import align_face\n",
    "    if not os.path.exists(\"shape_predictor_68_face_landmarks.dat\"):\n",
    "        print('Downloading files for aligning face image...')\n",
    "        os.system('wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2')\n",
    "        os.system('bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2')\n",
    "        print('Done.')\n",
    "    predictor = dlib.shape_predictor(\"shape_predictor_68_face_landmarks.dat\")\n",
    "    aligned_image = align_face(filepath=image_path, predictor=predictor) \n",
    "    print(\"Aligned image has shape: {}\".format(aligned_image.size))\n",
    "    return aligned_image "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aTZcKMdK8y77"
   },
   "outputs": [],
   "source": [
    "if experiment_type in ['ffhq_encode', 'toonify']:\n",
    "    input_image = run_alignment(image_path)\n",
    "else:\n",
    "    input_image = original_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hUBAfodh5PaM"
   },
   "outputs": [],
   "source": [
    "input_image.resize((256, 256))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "D0BmXzu1kTKg"
   },
   "source": [
    "## Step 6: Perform Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "T3h3E7VLkTKg"
   },
   "outputs": [],
   "source": [
    "img_transforms = EXPERIMENT_ARGS['transform']\n",
    "transformed_image = img_transforms(input_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_fNBlRU8OSDL"
   },
   "source": [
    "Before running inference, we need to generate the image corresponding to the average latent code. These will be used to initialize the iterative refinement process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fmpzoODNOSDL"
   },
   "outputs": [],
   "source": [
    "def get_avg_image(net):\n",
    "    avg_image = net(net.latent_avg.unsqueeze(0),\n",
    "                    input_code=True,\n",
    "                    randomize_noise=False,\n",
    "                    return_latents=False,\n",
    "                    average_code=True)[0]\n",
    "    avg_image = avg_image.to('cuda').float().detach()\n",
    "    if experiment_type == \"cars_encode\":\n",
    "        avg_image = avg_image[:, 32:224, :]\n",
    "    return avg_image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "M5eWR2S4OSDM"
   },
   "source": [
    "Now we'll run inference. By default, we'll run using 5 inference steps. You can change the parameter in the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ct_jm0obOSDM"
   },
   "outputs": [],
   "source": [
    "opts.n_iters_per_batch = 5\n",
    "opts.resize_outputs = False  # generate outputs at full resolution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ls5zb0fRkTKs"
   },
   "outputs": [],
   "source": [
    "from utils.inference_utils import run_on_batch\n",
    "\n",
    "with torch.no_grad():\n",
    "    avg_image = get_avg_image(net)\n",
    "    tic = time.time()\n",
    "    result_batch, result_latents = run_on_batch(transformed_image.unsqueeze(0).cuda(), net, opts, avg_image)\n",
    "    toc = time.time()\n",
    "    print('Inference took {:.4f} seconds.'.format(toc - tic))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Nq0dkSz6kTKv"
   },
   "source": [
    "### Visualize Result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UVR03XT_kTK0"
   },
   "source": [
    "We'll visualize the step-by-step outputs side by side."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ca5BtxdUOSDN"
   },
   "outputs": [],
   "source": [
    "if opts.dataset_type == \"cars_encode\":\n",
    "    resize_amount = (256, 192) if opts.resize_outputs else (512, 384)\n",
    "else:\n",
    "    resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WdR51hOROSDN"
   },
   "outputs": [],
   "source": [
    "def get_coupled_results(result_batch, transformed_image):\n",
    "    \"\"\"\n",
    "    Visualize output images from left to right (the input image is on the right)\n",
    "    \"\"\"\n",
    "    result_tensors = result_batch[0]  # there's one image in our batch\n",
    "    result_images = [tensor2im(result_tensors[iter_idx]) for iter_idx in range(opts.n_iters_per_batch)]\n",
    "    input_im = tensor2im(transformed_image)\n",
    "    res = np.array(result_images[0].resize(resize_amount))\n",
    "    for idx, result in enumerate(result_images[1:]):\n",
    "        res = np.concatenate([res, np.array(result.resize(resize_amount))], axis=1)\n",
    "    res = np.concatenate([res, input_im.resize(resize_amount)], axis=1)\n",
    "    res = Image.fromarray(res)\n",
    "    return res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uSDCvtTMOSDN"
   },
   "source": [
    "Note that the step-by-step outputs are shown left-to-right with the original input on the right-hand side."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lb3raAKFOSDN"
   },
   "outputs": [],
   "source": [
    "res = get_coupled_results(result_batch, transformed_image)\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qaB7RN7cOSDN"
   },
   "outputs": [],
   "source": [
    "# save image \n",
    "res.save(f'./{experiment_type}_results.jpg')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ISEMFxmekTK7"
   },
   "source": [
    "# Encoder Bootstrapping\n",
    "\n",
    "In the paper, we introduce an encoder bootstrapping technique that can be used to solve the image toonification task by pairing an FFHQ-based encoder with a Toon-based encoder.  \n",
    "\n",
    "We demonstrate this idea below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Sv284Ox8OSDO"
   },
   "outputs": [],
   "source": [
    "# download the ffhq-based encoder if not previously downloaded\n",
    "path = MODEL_PATHS['ffhq_encode']\n",
    "EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS['ffhq_encode']\n",
    "ffhq_model_path = EXPERIMENT_ARGS['model_path']\n",
    "download_command = get_download_model_command(file_id=path[\"id\"], file_name=path[\"name\"]) \n",
    "if not os.path.exists(ffhq_model_path):\n",
    "    print('Downloading FFHQ ReStyle encoder...')\n",
    "    os.system(f\"wget {download_command}\")\n",
    "    print('Done.')\n",
    "else:\n",
    "    print('FFHQ ReStyle encoder already exists!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FKbAFK7_OSDO"
   },
   "outputs": [],
   "source": [
    "# download the toon-based encoder if not previously downloaded\n",
    "path = MODEL_PATHS['toonify']\n",
    "EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS['toonify']\n",
    "toonify_model_path = EXPERIMENT_ARGS['model_path']\n",
    "download_command = get_download_model_command(file_id=path[\"id\"], file_name=path[\"name\"]) \n",
    "if not os.path.exists(toonify_model_path):\n",
    "    print('Downloading Toonify ReStyle encoder...')\n",
    "    os.system(f\"wget {download_command}\")\n",
    "    print('Done.')\n",
    "else:\n",
    "    print('Toonify ReStyle encoder already exists!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K3v0X3ZWkTK8"
   },
   "outputs": [],
   "source": [
    "# load models \n",
    "ckpt = torch.load(ffhq_model_path, map_location='cpu')\n",
    "opts = ckpt['opts']\n",
    "opts['checkpoint_path'] = ffhq_model_path\n",
    "opts = Namespace(**opts)\n",
    "net1 = pSp(opts)\n",
    "net1.eval()\n",
    "net1.cuda()\n",
    "print('FFHQ Model successfully loaded!')\n",
    "\n",
    "ckpt = torch.load(toonify_model_path, map_location='cpu')\n",
    "opts = ckpt['opts']\n",
    "opts['checkpoint_path'] = toonify_model_path\n",
    "opts = Namespace(**opts)\n",
    "net2 = pSp(opts)\n",
    "net2.eval()\n",
    "net2.cuda()\n",
    "print('Toonify Model successfully loaded!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XW-CJsuwOSDO"
   },
   "outputs": [],
   "source": [
    "# load image \n",
    "image_path = EXPERIMENT_DATA_ARGS['toonify'][\"image_path\"]\n",
    "original_image = Image.open(image_path).convert(\"RGB\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MmPWPODaOSDP"
   },
   "outputs": [],
   "source": [
    "# transform image\n",
    "img_transforms = EXPERIMENT_ARGS['transform']\n",
    "transformed_image = img_transforms(original_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BiMjTyMzOSDP"
   },
   "outputs": [],
   "source": [
    "opts.n_iters_per_batch = 5\n",
    "opts.resize_outputs = False  # generate outputs at full resolution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "o81i-MtOOSDQ"
   },
   "outputs": [],
   "source": [
    "from scripts.encoder_bootstrapping_inference import run_on_batch\n",
    "\n",
    "with torch.no_grad():\n",
    "    avg_image = get_avg_image(net1)\n",
    "    tic = time.time()\n",
    "    result_batch = run_on_batch(transformed_image.unsqueeze(0).cuda(), net1, net2, opts, avg_image)\n",
    "    toc = time.time()\n",
    "    print('Inference took {:.4f} seconds.'.format(toc - tic))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1AGWnm9BOSDQ"
   },
   "source": [
    "Again we'll visualize the results from left to right. Here, the leftmost image is the inverted FFHQ image that is used to initialize the toonify ReStyle encoder. The following images show iterative results outputted by the toonify model.\n",
    "Finally, the rightmost image is the original input image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FX-_45rxOSDQ"
   },
   "outputs": [],
   "source": [
    "res = get_coupled_results(result_batch, transformed_image)\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Pdk_QLRFOSDQ"
   },
   "outputs": [],
   "source": [
    "# save image \n",
    "res.save(f'./encoder_bootstrapping_results.jpg')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "inference_playground.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
